import torch.nn as nn
import torch

"""
    定义LSTM的参数含义：(input_size, hidden_size, num_layers)
    定义输入张量的参数含义：(sequence_length, batch_size, input_size)
    定义隐藏层初始张量和细胞状态张量的参数含义：(number_layers*num_directions, batch_size, hidden_size)
"""
lstm = nn.LSTM(5, 6, 2)
input_ = torch.randn(1, 3, 5)
h0 = torch.randn(2, 3, 6)
c0 = torch.randn(2, 3, 6)
output, (hn, cn) = lstm(input_, (h0, c0))
print(output)
print(output.shape)
print(hn)
print(hn.shape)
print(cn)
print(cn.shape)
